from tqdm import tqdm
import torch
import torch.nn.functional as F
import os, pickle
import numpy as np
import torch.utils.data as data
from torch.utils.data import DataLoader


class Analysis:
    """Base class for per-experiment analyses with pickle-based caching.

    Subclasses must implement ``_file_name`` and set ``results_all``
    (experiment name -> result dict) and ``results_exp`` (this instance's
    own experiment dict) before instances are combined with ``+``.
    Instances can also be combined with ``sum([a, b, ...])``.
    """

    def __init__(self, rootdir, args, overwrite):
        """Store configuration shared by all analyses.

        Args:
            rootdir: Root directory, stored for ``_file_name`` implementations.
            args: Config namespace; only ``args.exp`` is read here.
            overwrite: Flag stored for subclasses (not read in this class).
        """
        self.args = args
        self.rootdir = rootdir
        self.overwrite = overwrite
        # Not used in this class; presumably plotting defaults for subclasses.
        self._fig_size = 4
        self._dpi = 200
        self.exps = [args.exp]

    def _file_name(self, postfix, rootdir=None):
        """Return the cache path for *postfix*, without the ``.pkl`` suffix.

        Must be overridden by subclasses.
        """
        raise NotImplementedError("Needs a _file_name function")

    def _exist(self, attack):
        """Return True if a pickled result for *attack* already exists."""
        return os.path.exists(self._file_name(attack) + ".pkl")

    def _load(self, attack):
        """Load and return the pickled result for *attack*.

        Unpickling can execute arbitrary code, so only load files written by
        ``_save`` (or another trusted source).
        """
        with open(self._file_name(attack) + ".pkl", "rb") as f:
            return pickle.load(f)

    def _save(self, attack, result_dict):
        """Pickle *result_dict* to the cache file for *attack*."""
        with open(self._file_name(attack) + ".pkl", "wb") as f:
            pickle.dump(result_dict, f)

    def __add__(self, other):
        """Merge *other*'s experiments into ``self`` and return ``self``.

        This mutates and returns ``self`` (in-place semantics) rather than
        building a new object.

        Raises:
            Exception: if *other* is not exactly the same type as ``self``.

        Warns:
            UserWarning: if the ``results_exp`` keys of the two analyses
                differ; the merge still proceeds.
        """
        import warnings

        if type(self) is not type(other):
            raise Exception("Summation : Check anaylsis types")
        if self.results_exp.keys() != other.results_exp.keys():
            # Non-fatal: report the mismatch but still merge the results.
            warnings.warn("Adding analysis class with different dictionary keys",
                          stacklevel=2)
        self.results_all.update(other.results_all)
        self.exps.extend(other.exps)
        return self

    def __radd__(self, other):
        """Support ``sum([...])`` by treating the start value ``0`` as identity.

        Any other left operand is unsupported. Returning ``NotImplemented``
        makes Python raise ``TypeError`` instead of producing ``None``.
        """
        if other == 0:
            return self
        return NotImplemented

class Adv_Analysis(Analysis):
    """Analysis that records a model's logits and the labels of a dataset.

    ``results_all`` maps ``args.exp`` to a dict whose ``"orig"`` entry holds
    the outputs collected from *dloader* at construction time.
    """

    def __init__(self, model, dloader, rootdir, args, overwrite):
        """Store the model and loader, then immediately collect outputs.

        Note: construction runs a full pass of *model* over *dloader*.

        Args:
            model: Model mapping an input batch to logits.
            dloader: Iterable of ``(x, y)`` batches.
            rootdir, args, overwrite: Forwarded to ``Analysis.__init__``;
                ``args.exp`` names this experiment.
        """
        super().__init__(rootdir=rootdir, args=args, overwrite=overwrite)
        self.model = model
        self.dloader = dloader
        self.results_all = {args.exp: dict()}
        self.results_exp = self.results_all[args.exp]
        self.results_exp["orig"] = self._collect_data(model, dloader, args)

    @staticmethod
    def _model_device(model):
        """Return the device of *model*'s first parameter.

        Falls back to the default CUDA device when the model exposes no
        parameters.
        """
        try:
            return next(model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cuda")

    @torch.no_grad()
    def _collect_data(self, model, loader, args, desc=None):
        """Run *model* over *loader* and gather labels, label scores and logits.

        Puts *model* in eval mode and leaves it there.

        Args:
            model: Model mapping an input batch to logits.
            loader: Iterable of ``(x, y)`` tensor batches. ``y`` is either
                integer class indices (``torch.long``) or a per-class score
                matrix (presumably confidences; confirm with callers).
            args: Config; ``args.num_labels`` is read when ``y`` is integer.
            desc: Optional progress-bar label.

        Returns:
            dict of lists with one entry per sample:
                "ys": argmax class of ``y``;
                "y_probs": value of ``y`` at that class (1.0 for integer
                    labels, which are one-hot encoded first);
                "logits": the corresponding row of the model output.
        """
        model.eval()
        # Send inputs to wherever the model lives instead of hard-coding CUDA.
        device = self._model_device(model)
        data_dict = {"ys": [],
                     "y_probs": [],  # value of y at its argmax class
                     "logits": []}
        if desc is None:
            desc = "Collecting x=>y outputs for Adv Analysis"
        for x, y in tqdm(loader, desc=desc):
            x = x.to(device, non_blocking=True)
            # Labels are only converted to numpy, so keep them on the CPU.
            y = y.cpu()
            if y.dtype == torch.long:
                y = F.one_hot(y, num_classes=args.num_labels).float()
            logits = model(x)
            labels = y.argmax(dim=1)
            # Equivalent to y[i, labels[i]] for every row, without a Python loop.
            probs = y.gather(1, labels.unsqueeze(1)).squeeze(1)
            data_dict["ys"].extend(labels.numpy())
            data_dict["y_probs"].extend(probs.numpy())
            data_dict["logits"].extend(logits.cpu().numpy())
        return data_dict

    def get_xys(self, model=None, loader=None, args=None, desc=None):
        """Collect outputs like ``_collect_data``.

        Any argument left as ``None`` defaults to the model, loader or args
        this instance was built with.
        """
        return self._collect_data(
            model=self.model if model is None else model,
            loader=self.dloader if loader is None else loader,
            args=self.args if args is None else args,
            desc=desc,
        )

    def _file_name(self, postfix, rootdir=None):
        """Return the cache path for *postfix*; must be overridden."""
        raise NotImplementedError("Needs a _file_name function")
